{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add `src` directory to path (change this to your `src` directory)\n",
    "import sys\n",
    "sys.path.insert(0, \"/home/cody/abcnn/ABCNN/src\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import `standard` modules\n",
    "import os\n",
    "from torch.utils.data import TensorDataset\n",
    "\n",
    "# Import custom modules\n",
    "from setup import read_config\n",
    "from setup import setup\n",
    "from trainer.factories import loss_fn_factory\n",
    "from trainer.factories import optimizer_factory\n",
    "from trainer.factories import scheduler_factory\n",
    "from trainer.multiclass_classifier_trainer import MulticlassClassifierTrainer\n",
    "from trainer.utils import move_to_device\n",
    "from utils import abcnn_model_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the input values\n",
    "CONFIG_PATH = \"/home/cody/abcnn/ABCNN/src/config.yaml\"# path to configuration file\n",
    "TRAINSET = \"moveworks_train\"# name of training set (should be a key in data_paths)\n",
    "VALSET = \"moveworks_val\"# name of validation set (should be a key in data_paths)\n",
    "TESTSET = \"moveworks_test\"# name of test set (should be a key in data_paths)\n",
    "LOAD_PATH = None # load a model from a checkpoint file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check input values\n",
    "assert(os.path.isfile(CONFIG_PATH))\n",
    "assert(LOAD_PATH is None or os.path.isfile(LOAD_PATH))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "moveworks_train: 100%|██████████| 7868/7868 [00:02<00:00, 2742.42it/s]\n",
      "moveworks_val: 100%|██████████| 438/438 [00:00<00:00, 2770.56it/s]\n",
      "moveworks_test: 100%|██████████| 437/437 [00:00<00:00, 2770.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading FastText word vectors from: /home/cody/abcnn/embeddings/fasttext/tickets/word_vector_from_tickets_skipgram_dim300_subword_min2_max6.bin\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "embedding matrix: 100%|██████████| 2666/2666 [00:00<00:00, 124026.87it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating the ABCNN model...\n"
     ]
    },
    {
     "ename": "KeyError",
     "evalue": "'environment'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-5-de1f0e4bd0f0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptimizer_factory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"optimizer\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0mscheduler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscheduler_factory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"scheduler\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mtrainer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMulticlassClassifierTrainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"trainer\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     13\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mLOAD_PATH\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m     \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mabcnn_model_loader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mLOAD\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/abcnn/ABCNN/src/trainer/multiclass_classifier_trainer.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, config)\u001b[0m\n\u001b[1;32m     65\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     66\u001b[0m         \u001b[0;31m# Hacky way to get tqdm to work in the shell and in jupyter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 67\u001b[0;31m         \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"environment\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"script\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     68\u001b[0m             \u001b[0;32mfrom\u001b[0m \u001b[0mtqdm\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     69\u001b[0m             \u001b[0;32mfrom\u001b[0m \u001b[0mtqdm\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrange\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyError\u001b[0m: 'environment'"
     ]
    }
   ],
   "source": [
    "# Setup modules\n",
    "config = read_config(CONFIG_PATH)\n",
    "features, labels, model = setup(config[\"model\"])\n",
    "model = move_to_device(config[\"trainer\"][\"device\"], model) # hacky, but necessary for trainer\n",
    "datasets = {\n",
    "    name: TensorDataset(features[name], labels[name])\n",
    "    for name in features\n",
    "}\n",
    "loss_fn = loss_fn_factory(config[\"loss_fn\"])\n",
    "optimizer = optimizer_factory(config[\"optimizer\"], model.parameters())\n",
    "scheduler = scheduler_factory(config[\"scheduler\"], optimizer)\n",
    "trainer = MulticlassClassifierTrainer(config[\"trainer\"])\n",
    "if LOAD_PATH:\n",
    "    model, optimizer = abcnn_model_loader(LOAD, model, optimizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "trainset = datasets[TRAINSET]\n",
    "valset = datasets[VALSET] if VALSET else None\n",
    "trainer.train(loss_fn, model, optimizer, trainset, scheduler=scheduler, valset=valset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "testset = datasets[TESTSET]\n",
    "trainer.predict(testset)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
